{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.1"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "softmax-regression.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "46118d448f2a4ef2acf8a482e65f446f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_e37bbe66ab2c478292c0088bdc0e79ad",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_af20edd4dadb4f64a6340cfce0abaa6c",
              "IPY_MODEL_64a3897557c34be2ac654e7320fb16ae"
            ]
          }
        },
        "e37bbe66ab2c478292c0088bdc0e79ad": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "af20edd4dadb4f64a6340cfce0abaa6c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_8fa6bf89ff6c42409b19c97209304b29",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_265a8c45e7214224a585bdabc8c7a4ab"
          }
        },
        "64a3897557c34be2ac654e7320fb16ae": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_a74a42321ef0483ba37eb373fb33a5af",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:20&lt;00:00, 17755371.26it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_94aa2605fc2e42198b479d0505cc2fde"
          }
        },
        "8fa6bf89ff6c42409b19c97209304b29": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "265a8c45e7214224a585bdabc8c7a4ab": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a74a42321ef0483ba37eb373fb33a5af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "94aa2605fc2e42198b479d0505cc2fde": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "af66af078a1b4efe897b584b31cfb015": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_439870d7d0554354be8be20b7a72fcfa",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_3e6de983389e4229aac183f6149f337d",
              "IPY_MODEL_f870c90aa1d542d08e1f09bcb79374e9"
            ]
          }
        },
        "439870d7d0554354be8be20b7a72fcfa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "3e6de983389e4229aac183f6149f337d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_44617b44fdc94c019dada5bab5bf8856",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_9144a604abda45e2ab822677ede35b8a"
          }
        },
        "f870c90aa1d542d08e1f09bcb79374e9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_36f2bfccb4c24ccf8f99b0baeac4d18c",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 32768/? [00:00&lt;00:00, 64889.23it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_9abb6fc34981461bbbfcfa159b2a2aaa"
          }
        },
        "44617b44fdc94c019dada5bab5bf8856": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "9144a604abda45e2ab822677ede35b8a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "36f2bfccb4c24ccf8f99b0baeac4d18c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "9abb6fc34981461bbbfcfa159b2a2aaa": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "3a50697eee8840e5bf702ffbd08237ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_a900310b039f44cc91f2a006a9dc9099",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_5405cc7a5ce043eca53b049490a4a8c4",
              "IPY_MODEL_13222aaeb4f543909d849a4cf322d150"
            ]
          }
        },
        "a900310b039f44cc91f2a006a9dc9099": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "5405cc7a5ce043eca53b049490a4a8c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_083358e7ae3a4a9eb2be1f1525986b9b",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_8f211aa6f9694c8aab02ede62c942ddc"
          }
        },
        "13222aaeb4f543909d849a4cf322d150": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_fb4316530246429e9d549ffba5f76ff0",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:00&lt;00:00, 4038901.91it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f06767e668d74c9a80b18dfb6412efac"
          }
        },
        "083358e7ae3a4a9eb2be1f1525986b9b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "8f211aa6f9694c8aab02ede62c942ddc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "fb4316530246429e9d549ffba5f76ff0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f06767e668d74c9a80b18dfb6412efac": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "413936db123c4778b42106d121110ca8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_9d3510d7af1f40c6aed4e9a9c94c696f",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_140b9765a33442ad8398c992997fa6f0",
              "IPY_MODEL_73e3ee2b3356402aa482c4cc62f00230"
            ]
          }
        },
        "9d3510d7af1f40c6aed4e9a9c94c696f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "140b9765a33442ad8398c992997fa6f0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_a61f4b35d0a541a0b9f4df09a93d59a0",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_755abfca669940d5be02c8598a43275e"
          }
        },
        "73e3ee2b3356402aa482c4cc62f00230": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_c0d8da7a3e4044e0afdc96de56136780",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 8192/? [00:00&lt;00:00, 60746.39it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3b4ceeae7080474bab97d8f708cbdb46"
          }
        },
        "a61f4b35d0a541a0b9f4df09a93d59a0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "755abfca669940d5be02c8598a43275e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c0d8da7a3e4044e0afdc96de56136780": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3b4ceeae7080474bab97d8f708cbdb46": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/DeepSE/deeplearning-models/blob/master/pytorch_ipynb/basic-ml/softmax-regression.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J5u7ENBQFl3X",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x3_3aQ5yFnwb",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "outputId": "77d60801-b3f4-4d44-efa7-2b8a2ff339ac"
      },
      "source": [
        "# Install third-party packages (-q keeps pip's output short).\n",
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q torch\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "!pip install -q tensorwatch\n",
        "# The PyPI distribution is named `scikit-learn`; the old `sklearn`\n",
        "# placeholder package is deprecated and no longer installs.\n",
        "!pip install -q scikit-learn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[K     |████████████████████████████████| 194kB 2.8MB/s \n",
            "\u001b[K     |████████████████████████████████| 143kB 8.6MB/s \n",
            "\u001b[?25h  Building wheel for tensorwatch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for pydotz (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AKddULAnFl3a",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "8c15b381-5108-4e38-a3b7-cace2ce8d714"
      },
      "source": [
        "%load_ext watermark\n",
        "%watermark -a 'Sebastian Raschka' -v -p torch"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GE6ECV25Fl3i",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K4YD22BIFl3k",
        "colab_type": "text"
      },
      "source": [
        "# Model Zoo -- Softmax Regression"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NqfVDhXgFl3l",
        "colab_type": "text"
      },
      "source": [
        "Implementation of softmax regression (multinomial logistic regression)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1kZp4A8FFl3m",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hgv6ijJwFl3n",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "import torch.nn.functional as F\n",
        "import torch"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qWm3BxQmFl3s",
        "colab_type": "text"
      },
      "source": [
        "## Settings and Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4Tqsu8VcFl3t",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 441,
          "referenced_widgets": [
            "46118d448f2a4ef2acf8a482e65f446f",
            "e37bbe66ab2c478292c0088bdc0e79ad",
            "af20edd4dadb4f64a6340cfce0abaa6c",
            "64a3897557c34be2ac654e7320fb16ae",
            "8fa6bf89ff6c42409b19c97209304b29",
            "265a8c45e7214224a585bdabc8c7a4ab",
            "a74a42321ef0483ba37eb373fb33a5af",
            "94aa2605fc2e42198b479d0505cc2fde",
            "af66af078a1b4efe897b584b31cfb015",
            "439870d7d0554354be8be20b7a72fcfa",
            "3e6de983389e4229aac183f6149f337d",
            "f870c90aa1d542d08e1f09bcb79374e9",
            "44617b44fdc94c019dada5bab5bf8856",
            "9144a604abda45e2ab822677ede35b8a",
            "36f2bfccb4c24ccf8f99b0baeac4d18c",
            "9abb6fc34981461bbbfcfa159b2a2aaa",
            "3a50697eee8840e5bf702ffbd08237ab",
            "a900310b039f44cc91f2a006a9dc9099",
            "5405cc7a5ce043eca53b049490a4a8c4",
            "13222aaeb4f543909d849a4cf322d150",
            "083358e7ae3a4a9eb2be1f1525986b9b",
            "8f211aa6f9694c8aab02ede62c942ddc",
            "fb4316530246429e9d549ffba5f76ff0",
            "f06767e668d74c9a80b18dfb6412efac",
            "413936db123c4778b42106d121110ca8",
            "9d3510d7af1f40c6aed4e9a9c94c696f",
            "140b9765a33442ad8398c992997fa6f0",
            "73e3ee2b3356402aa482c4cc62f00230",
            "a61f4b35d0a541a0b9f4df09a93d59a0",
            "755abfca669940d5be02c8598a43275e",
            "c0d8da7a3e4044e0afdc96de56136780",
            "3b4ceeae7080474bab97d8f708cbdb46"
          ]
        },
        "outputId": "be3f666c-998c-43ab-bd98-4c62ced1aedc"
      },
      "source": [
        "##########################\n",
        "### SETTINGS\n",
        "##########################\n",
        "\n",
        "# Compute device: the first CUDA GPU when one is available, otherwise the CPU\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda:0\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")\n",
        "\n",
        "# Optimisation hyperparameters\n",
        "random_seed = 123\n",
        "learning_rate = 0.1\n",
        "num_epochs = 10\n",
        "batch_size = 256\n",
        "\n",
        "# Model dimensions: 784 inputs (a flattened 28x28 image) and 10 output classes\n",
        "num_features = 784\n",
        "num_classes = 10\n",
        "\n",
        "\n",
        "##########################\n",
        "### MNIST DATASET\n",
        "##########################\n",
        "\n",
        "# Training split, stored under ./data (download requested)\n",
        "train_dataset = datasets.MNIST(\n",
        "    root='data', train=True, download=True,\n",
        "    transform=transforms.ToTensor())\n",
        "\n",
        "# Test split, read from the same root directory (no download requested)\n",
        "test_dataset = datasets.MNIST(\n",
        "    root='data', train=False,\n",
        "    transform=transforms.ToTensor())\n",
        "\n",
        "\n",
        "# Mini-batch iterators: training batches are reshuffled, test order is fixed\n",
        "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
        "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "\n",
        "# Sanity check: print the tensor shapes of the first training batch only\n",
        "for images, labels in train_loader:\n",
        "    print('Image batch dimensions:', images.shape)\n",
        "    print('Image label dimensions:', labels.shape)\n",
        "    break"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "46118d448f2a4ef2acf8a482e65f446f",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "af66af078a1b4efe897b584b31cfb015",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3a50697eee8840e5bf702ffbd08237ab",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "413936db123c4778b42106d121110ca8",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Processing...\n",
            "\n",
            "\n",
            "\n",
            "Done!\n",
            "Image batch dimensions: torch.Size([256, 1, 28, 28])\n",
            "Image label dimensions: torch.Size([256])\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C9jbyaFoFl3y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### MODEL\n",
        "##########################\n",
        "\n",
        "class SoftmaxRegression(torch.nn.Module):\n",
        "    \"\"\"Softmax (multinomial logistic) regression: one linear layer plus softmax.\n",
        "\n",
        "    Weights and bias start at zero, so before any training every class\n",
        "    receives the same probability.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_features, num_classes):\n",
        "        \"\"\"Create the linear layer mapping num_features inputs to num_classes logits.\"\"\"\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(num_features, num_classes)\n",
        "\n",
        "        # Replace the default random initialisation with zeros, outside autograd\n",
        "        with torch.no_grad():\n",
        "            for param in (self.linear.weight, self.linear.bias):\n",
        "                param.zero_()\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the pair (logits, probas) for the input batch x.\n",
        "\n",
        "        probas is the softmax of logits taken over dim 1.\n",
        "        \"\"\"\n",
        "        logits = self.linear(x)\n",
        "        return logits, F.softmax(logits, dim=1)\n",
        "\n",
        "# Build the classifier and move its parameters to the selected device\n",
        "model = SoftmaxRegression(num_features, num_classes).to(device)\n",
        "\n",
        "##########################\n",
        "### COST AND OPTIMIZER\n",
        "##########################\n",
        "\n",
        "# Plain SGD over all model parameters; the cost itself is computed in the training loop\n",
        "optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bya760KnF-H-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 165
        },
        "outputId": "28b266e2-1fb2-4345-f1e6-b23dcfb0586f"
      },
      "source": [
        "import hiddenlayer as hl\n",
        "\n",
        "# Trace the model with a dummy batch of 62 flattened images to draw its graph.\n",
        "# The batch must live on the same device as the model; a CPU tensor fed to a\n",
        "# model that was moved to the GPU raises a device-mismatch error.\n",
        "hl.build_graph(model, torch.zeros([62, 28*28], device=device))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<hiddenlayer.graph.Graph at 0x7f3f48c938d0>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"315pt\" height=\"108pt\"\n viewBox=\"0.00 0.00 315.00 108.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 72)\">\n<title>%3</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-72,36 -72,-72 243,-72 243,36 -72,36\"/>\n<!-- /outputs/3 -->\n<g id=\"node1\" class=\"node\">\n<title>/outputs/3</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-36 0,-36 0,0 54,0 54,-36\"/>\n<text text-anchor=\"start\" x=\"14\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n</g>\n<!-- /outputs/4 -->\n<g id=\"node2\" class=\"node\">\n<title>/outputs/4</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"171,-36 117,-36 117,0 171,0 171,-36\"/>\n<text text-anchor=\"start\" x=\"127\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Softmax</text>\n</g>\n<!-- /outputs/3&#45;&gt;/outputs/4 -->\n<g id=\"edge1\" class=\"edge\">\n<title>/outputs/3&#45;&gt;/outputs/4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M54.1362,-18C69.6773,-18 89.4112,-18 106.4253,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"106.6057,-21.5001 116.6056,-18 106.6056,-14.5001 106.6057,-21.5001\"/>\n<text text-anchor=\"middle\" x=\"85.5\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">62x10</text>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QmUIxlyQFl34",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "4e302a4a-e57f-4361-c7a3-2860a16edd70"
      },
      "source": [
        "# Manual seed for deterministic data loader\n",
        "torch.manual_seed(random_seed)\n",
        "\n",
        "\n",
        "def compute_accuracy(model, data_loader):\n",
        "    \"\"\"Return the classification accuracy of model on data_loader, in percent.\n",
        "\n",
        "    Every batch is flattened to 28*28 columns, moved to the global `device`\n",
        "    and labelled with the class of highest predicted probability. Gradient\n",
        "    tracking is switched off inside the function, so callers do not have to\n",
        "    wrap the call themselves.\n",
        "\n",
        "    Args:\n",
        "        model: callable returning a (logits, probas) pair for a feature batch.\n",
        "        data_loader: iterable yielding (features, targets) batches.\n",
        "\n",
        "    Returns:\n",
        "        A 0-dim float tensor with the accuracy in percent (NaN when\n",
        "        data_loader yields no batches).\n",
        "    \"\"\"\n",
        "    # Tensor accumulator, so an empty loader still produces a tensor result\n",
        "    correct_pred = torch.zeros((), dtype=torch.long, device=device)\n",
        "    num_examples = 0\n",
        "\n",
        "    # Pure evaluation: no autograd graph is needed for the forward passes\n",
        "    with torch.no_grad():\n",
        "        for features, targets in data_loader:\n",
        "            features = features.view(-1, 28*28).to(device)\n",
        "            targets = targets.to(device)\n",
        "            logits, probas = model(features)\n",
        "            _, predicted_labels = torch.max(probas, 1)\n",
        "            num_examples += targets.size(0)\n",
        "            correct_pred += (predicted_labels == targets).sum()\n",
        "\n",
        "    return correct_pred.float() / num_examples * 100\n",
        "    \n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
        "\n",
        "        # Flatten the images and move the batch to the compute device\n",
        "        features = features.view(-1, 28*28).to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        ### FORWARD AND BACK PROP\n",
        "        logits, probas = model(features)\n",
        "\n",
        "        # F.cross_entropy expects raw logits rather than\n",
        "        # probabilities, so probas is not used for the cost\n",
        "        cost = F.cross_entropy(logits, targets)\n",
        "        optimizer.zero_grad()\n",
        "        cost.backward()\n",
        "\n",
        "        ### UPDATE MODEL PARAMETERS\n",
        "        optimizer.step()\n",
        "\n",
        "        ### LOGGING\n",
        "        # Report the cost on every 50th batch, starting with batch 0\n",
        "        if batch_idx % 50 == 0:\n",
        "            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' % (\n",
        "                  epoch+1, num_epochs, batch_idx,\n",
        "                  len(train_dataset)//batch_size, cost))\n",
        "\n",
        "    # After each epoch, measure training accuracy without gradient tracking\n",
        "    with torch.no_grad():\n",
        "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
        "              epoch+1, num_epochs,\n",
        "              compute_accuracy(model, train_loader)))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001/010 | Batch 000/234 | Cost: 2.3026\n",
            "Epoch: 001/010 | Batch 050/234 | Cost: 0.7941\n",
            "Epoch: 001/010 | Batch 100/234 | Cost: 0.5651\n",
            "Epoch: 001/010 | Batch 150/234 | Cost: 0.4603\n",
            "Epoch: 001/010 | Batch 200/234 | Cost: 0.4822\n",
            "Epoch: 001/010 training accuracy: 88.04%\n",
            "Epoch: 002/010 | Batch 000/234 | Cost: 0.4105\n",
            "Epoch: 002/010 | Batch 050/234 | Cost: 0.4415\n",
            "Epoch: 002/010 | Batch 100/234 | Cost: 0.4367\n",
            "Epoch: 002/010 | Batch 150/234 | Cost: 0.4289\n",
            "Epoch: 002/010 | Batch 200/234 | Cost: 0.3926\n",
            "Epoch: 002/010 training accuracy: 89.37%\n",
            "Epoch: 003/010 | Batch 000/234 | Cost: 0.4112\n",
            "Epoch: 003/010 | Batch 050/234 | Cost: 0.3579\n",
            "Epoch: 003/010 | Batch 100/234 | Cost: 0.3013\n",
            "Epoch: 003/010 | Batch 150/234 | Cost: 0.3258\n",
            "Epoch: 003/010 | Batch 200/234 | Cost: 0.4254\n",
            "Epoch: 003/010 training accuracy: 89.98%\n",
            "Epoch: 004/010 | Batch 000/234 | Cost: 0.3988\n",
            "Epoch: 004/010 | Batch 050/234 | Cost: 0.3690\n",
            "Epoch: 004/010 | Batch 100/234 | Cost: 0.3459\n",
            "Epoch: 004/010 | Batch 150/234 | Cost: 0.4030\n",
            "Epoch: 004/010 | Batch 200/234 | Cost: 0.3240\n",
            "Epoch: 004/010 training accuracy: 90.35%\n",
            "Epoch: 005/010 | Batch 000/234 | Cost: 0.3265\n",
            "Epoch: 005/010 | Batch 050/234 | Cost: 0.3673\n",
            "Epoch: 005/010 | Batch 100/234 | Cost: 0.3085\n",
            "Epoch: 005/010 | Batch 150/234 | Cost: 0.3183\n",
            "Epoch: 005/010 | Batch 200/234 | Cost: 0.3316\n",
            "Epoch: 005/010 training accuracy: 90.64%\n",
            "Epoch: 006/010 | Batch 000/234 | Cost: 0.4518\n",
            "Epoch: 006/010 | Batch 050/234 | Cost: 0.3863\n",
            "Epoch: 006/010 | Batch 100/234 | Cost: 0.3620\n",
            "Epoch: 006/010 | Batch 150/234 | Cost: 0.3733\n",
            "Epoch: 006/010 | Batch 200/234 | Cost: 0.3289\n",
            "Epoch: 006/010 training accuracy: 90.86%\n",
            "Epoch: 007/010 | Batch 000/234 | Cost: 0.3450\n",
            "Epoch: 007/010 | Batch 050/234 | Cost: 0.2289\n",
            "Epoch: 007/010 | Batch 100/234 | Cost: 0.3073\n",
            "Epoch: 007/010 | Batch 150/234 | Cost: 0.2750\n",
            "Epoch: 007/010 | Batch 200/234 | Cost: 0.3456\n",
            "Epoch: 007/010 training accuracy: 91.00%\n",
            "Epoch: 008/010 | Batch 000/234 | Cost: 0.4900\n",
            "Epoch: 008/010 | Batch 050/234 | Cost: 0.3479\n",
            "Epoch: 008/010 | Batch 100/234 | Cost: 0.2343\n",
            "Epoch: 008/010 | Batch 150/234 | Cost: 0.3059\n",
            "Epoch: 008/010 | Batch 200/234 | Cost: 0.3684\n",
            "Epoch: 008/010 training accuracy: 91.22%\n",
            "Epoch: 009/010 | Batch 000/234 | Cost: 0.3762\n",
            "Epoch: 009/010 | Batch 050/234 | Cost: 0.2976\n",
            "Epoch: 009/010 | Batch 100/234 | Cost: 0.2690\n",
            "Epoch: 009/010 | Batch 150/234 | Cost: 0.2610\n",
            "Epoch: 009/010 | Batch 200/234 | Cost: 0.3140\n",
            "Epoch: 009/010 training accuracy: 91.34%\n",
            "Epoch: 010/010 | Batch 000/234 | Cost: 0.2790\n",
            "Epoch: 010/010 | Batch 050/234 | Cost: 0.3070\n",
            "Epoch: 010/010 | Batch 100/234 | Cost: 0.3300\n",
            "Epoch: 010/010 | Batch 150/234 | Cost: 0.2520\n",
            "Epoch: 010/010 | Batch 200/234 | Cost: 0.3301\n",
            "Epoch: 010/010 training accuracy: 91.40%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U1mrm2smFl39",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "507cbcd7-7e6f-4048-ee71-48fdc78530ea"
      },
      "source": [
        "# Evaluation only: disable gradient tracking, matching the training-accuracy check\n",
        "with torch.set_grad_enabled(False):\n",
        "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test accuracy: 91.77%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jmb6J6x7Fl4C",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "8e41e2cd-3f3a-4417-e20b-408879326d3f"
      },
      "source": [
        "%watermark -iv"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch 1.5.1+cu101\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}